{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1969184"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rough weight count: conv kernels + linear layers (biases and BatchNorm\n",
    "# parameters excluded).  conv1 is 3*3*1*32 — the original had a typo (3*331*32).\n",
    "3*3*1*32+3*3*32*32+3*3*32*64+3*3*64*64+3136*512+512*512+512*10\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fcc38058790>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameters and reproducibility settings.\n",
    "n_epochs = 3\n",
    "batch_size_train = 64\n",
    "batch_size_test = 1000\n",
    "learning_rate = 0.01\n",
    "momentum = 0.5\n",
    "log_interval = 10  # print/record loss and checkpoint every 10 batches (see train())\n",
    "random_seed = 1\n",
    "torch.manual_seed(random_seed)  # fix the RNG for reproducible runs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loader: downloads MNIST to ./data/ on first run.\n",
    "# NOTE(review): 0.1307/0.3081 look like the standard MNIST mean/std — confirm.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "  torchvision.datasets.MNIST('./data/', train=True, download=True,\n",
    "                             transform=torchvision.transforms.Compose([\n",
    "                               torchvision.transforms.ToTensor(),\n",
    "                               torchvision.transforms.Normalize(\n",
    "                                 (0.1307,), (0.3081,))\n",
    "                             ])),\n",
    "  batch_size=batch_size_train, shuffle=True)\n",
    "\n",
    "# Test loader: same transform, larger batch size.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "  torchvision.datasets.MNIST('./data/', train=False, download=True,\n",
    "                             transform=torchvision.transforms.Compose([\n",
    "                               torchvision.transforms.ToTensor(),\n",
    "                               torchvision.transforms.Normalize(\n",
    "                                 (0.1307,), (0.3081,))\n",
    "                             ])),\n",
    "  batch_size=batch_size_test, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([4, 4, 2, 1, 8, 8, 2, 8, 3, 8, 5, 5, 5, 3, 0, 3, 6, 3, 0, 7, 0, 7, 2, 3,\n",
      "        3, 4, 8, 2, 6, 1, 1, 2, 9, 5, 8, 2, 0, 3, 2, 3, 9, 9, 9, 3, 7, 9, 5, 7,\n",
      "        2, 1, 6, 7, 2, 9, 3, 1, 3, 2, 0, 4, 8, 7, 5, 2, 4, 6, 5, 3, 9, 8, 9, 6,\n",
      "        2, 8, 3, 9, 5, 9, 5, 0, 0, 4, 0, 2, 3, 0, 2, 8, 5, 8, 2, 2, 1, 3, 1, 5,\n",
      "        1, 7, 2, 8, 8, 4, 2, 2, 9, 3, 1, 1, 4, 7, 2, 8, 7, 6, 2, 7, 6, 3, 6, 1,\n",
      "        9, 0, 9, 7, 7, 9, 4, 0, 9, 4, 0, 2, 0, 7, 5, 9, 0, 4, 6, 7, 9, 4, 2, 4,\n",
      "        0, 1, 2, 1, 4, 9, 3, 5, 3, 7, 8, 8, 3, 8, 5, 9, 1, 8, 3, 6, 8, 3, 3, 8,\n",
      "        1, 6, 7, 4, 3, 3, 5, 6, 9, 6, 0, 2, 5, 1, 8, 0, 1, 4, 1, 1, 1, 7, 2, 7,\n",
      "        1, 1, 0, 4, 7, 2, 4, 6, 1, 4, 3, 6, 2, 7, 5, 4, 1, 3, 3, 3, 4, 7, 8, 0,\n",
      "        4, 3, 5, 8, 9, 4, 1, 9, 4, 7, 7, 1, 9, 6, 0, 4, 2, 9, 2, 1, 3, 6, 6, 0,\n",
      "        2, 2, 0, 0, 2, 5, 8, 9, 7, 0, 0, 9, 2, 0, 9, 2, 6, 9, 9, 5, 4, 3, 2, 9,\n",
      "        9, 4, 3, 5, 1, 3, 8, 5, 8, 4, 9, 5, 1, 4, 4, 6, 0, 7, 6, 8, 8, 8, 0, 9,\n",
      "        8, 4, 5, 9, 4, 6, 3, 9, 9, 4, 5, 6, 1, 1, 8, 6, 9, 7, 1, 3, 6, 0, 1, 9,\n",
      "        0, 6, 1, 6, 2, 8, 3, 1, 9, 0, 3, 0, 8, 1, 0, 0, 6, 3, 7, 3, 1, 0, 5, 0,\n",
      "        7, 1, 5, 5, 5, 2, 5, 1, 4, 3, 4, 8, 5, 8, 4, 6, 3, 2, 3, 9, 7, 0, 9, 0,\n",
      "        6, 1, 8, 4, 8, 8, 8, 1, 3, 5, 7, 1, 3, 1, 6, 0, 8, 7, 7, 3, 4, 2, 9, 6,\n",
      "        8, 1, 2, 4, 4, 7, 2, 6, 1, 5, 4, 8, 8, 1, 7, 2, 7, 7, 5, 7, 3, 0, 9, 2,\n",
      "        3, 7, 4, 6, 6, 5, 3, 1, 7, 3, 7, 3, 3, 9, 0, 7, 3, 4, 3, 6, 5, 8, 0, 7,\n",
      "        9, 6, 9, 9, 3, 4, 5, 8, 2, 6, 2, 1, 3, 1, 1, 3, 9, 2, 6, 4, 0, 7, 4, 8,\n",
      "        1, 5, 8, 2, 1, 7, 2, 7, 7, 8, 9, 5, 6, 5, 9, 3, 5, 4, 7, 6, 6, 6, 2, 0,\n",
      "        3, 3, 7, 0, 5, 7, 5, 2, 7, 1, 7, 2, 7, 2, 5, 9, 1, 0, 3, 5, 2, 5, 8, 1,\n",
      "        6, 9, 1, 7, 1, 1, 0, 5, 8, 9, 6, 3, 3, 4, 5, 7, 3, 0, 1, 6, 5, 8, 8, 1,\n",
      "        3, 3, 4, 7, 0, 0, 3, 4, 3, 4, 1, 1, 2, 5, 3, 1, 5, 2, 5, 8, 0, 2, 8, 8,\n",
      "        7, 3, 7, 4, 3, 1, 8, 0, 4, 9, 9, 1, 0, 6, 5, 8, 7, 0, 2, 1, 2, 4, 2, 3,\n",
      "        9, 1, 0, 6, 6, 6, 6, 4, 2, 5, 2, 0, 4, 8, 4, 2, 0, 0, 7, 0, 4, 2, 7, 5,\n",
      "        7, 8, 2, 1, 6, 6, 7, 9, 2, 9, 9, 9, 0, 9, 0, 6, 9, 8, 1, 3, 1, 4, 8, 8,\n",
      "        7, 1, 0, 8, 2, 0, 7, 9, 4, 4, 5, 7, 6, 8, 8, 9, 3, 4, 9, 7, 8, 9, 7, 7,\n",
      "        1, 1, 2, 0, 7, 8, 5, 7, 3, 4, 7, 7, 1, 7, 6, 4, 5, 3, 6, 3, 0, 8, 2, 2,\n",
      "        1, 7, 7, 2, 1, 7, 5, 4, 2, 4, 2, 1, 2, 7, 9, 5, 7, 3, 7, 1, 9, 0, 0, 5,\n",
      "        6, 1, 9, 4, 2, 6, 7, 3, 5, 5, 0, 1, 2, 7, 1, 4, 5, 7, 1, 5, 4, 5, 7, 4,\n",
      "        7, 7, 1, 0, 9, 1, 6, 7, 5, 3, 0, 3, 3, 8, 2, 2, 7, 2, 9, 3, 0, 2, 2, 6,\n",
      "        0, 2, 4, 2, 0, 1, 2, 7, 6, 7, 1, 7, 9, 6, 4, 5, 1, 6, 8, 4, 0, 5, 2, 7,\n",
      "        5, 4, 3, 3, 4, 6, 8, 5, 7, 6, 7, 0, 4, 2, 3, 5, 7, 6, 6, 8, 2, 1, 3, 4,\n",
      "        8, 5, 9, 2, 8, 9, 4, 5, 9, 7, 0, 7, 2, 3, 0, 4, 6, 6, 6, 2, 9, 7, 4, 2,\n",
      "        7, 4, 9, 7, 7, 5, 0, 1, 2, 8, 2, 2, 3, 1, 7, 9, 6, 9, 5, 2, 0, 1, 8, 9,\n",
      "        1, 5, 6, 9, 9, 1, 0, 4, 5, 0, 9, 8, 9, 4, 7, 8, 2, 2, 6, 6, 7, 2, 6, 3,\n",
      "        7, 1, 0, 1, 2, 2, 9, 0, 2, 7, 6, 2, 3, 1, 7, 7, 1, 9, 0, 9, 7, 1, 4, 4,\n",
      "        5, 8, 2, 8, 1, 4, 6, 4, 3, 9, 9, 2, 8, 2, 3, 6, 6, 5, 1, 1, 2, 7, 4, 9,\n",
      "        5, 7, 7, 8, 1, 6, 2, 9, 0, 4, 1, 0, 6, 1, 1, 1, 1, 2, 9, 9, 5, 3, 1, 2,\n",
      "        1, 8, 3, 4, 5, 1, 9, 3, 6, 1, 6, 1, 9, 9, 8, 9, 9, 1, 6, 1, 4, 7, 5, 1,\n",
      "        8, 5, 6, 7, 5, 9, 0, 3, 2, 0, 2, 4, 6, 5, 3, 1, 6, 5, 4, 5, 0, 8, 3, 1,\n",
      "        9, 8, 7, 7, 3, 6, 5, 4, 9, 5, 3, 9, 1, 4, 9, 7])\n",
      "torch.Size([1000, 1, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "# Pull one test batch to inspect the labels and the input tensor shape.\n",
    "examples = enumerate(test_loader)\n",
    "batch_idx, (example_data, example_targets) = next(examples)\n",
    "print(example_targets)\n",
    "print(example_data.shape)  # (batch_size_test, 1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbFklEQVR4nO3de7hVVbnH8d8LCl7YgSmoIIgcS4UjQt7wkpohJYiXNuAtPaJpZqmJpPmox0LMC6iIgaZppnIQxAtKppaXfAAhtLQDZpgeBfGKggoiKHucP9ZiNsdsr7XXmmusy158P8+zn2e8jDnHHJs94N1zjrHGNOecAAAoVZtqdwAAUB9IKACAIEgoAIAgSCgAgCBIKACAIEgoAIAg6jqhmFlPM3NmtkkVrv26mQ2s9HURBmMHaW3MY6fkhGJmx5nZfDNbbWbvZctnmZmF6GC5mNmq2FeTma2JxScW2dYdZja2TP28PTs4dy5H+9XE2GHspMXYCT92zGyImc02s5Vm9o6Z/drMGoppo6SEYmbnS7pB0jhJ20naVtKZkg6Q1C7HOW1LuWYozrkOG74kLZE0NPZnUzYcV43fMmLXPlDSf1Tr+uXE2Ckvxs6/ncPYaVlHSWMldZW0m6RuyvwdF845l+ore/HVkhpbOO4OSTdJeiR7/MBsZ5+WtFLSIklHxo5/WtL3YvEpkmbHYqfM4Hkle/4kSZataytpvKTlkl6T9MPs8Zu00MfXJQ3Mlg+R9KakCyW9I+muZB9i/dhZ0hmSPpe0TtIqSQ/H2hwt6W+SPpI0TdJmRfz9biLpr5L6brhW2p9VrX0xdhg7jJ3aHDuJa31H0v8Wc04pdyj7SWovaWYBx54g6QpJDZLmS3pY0uOSukg6W9IUM9uliGsfIWlvZf7BjJD0reyfn56t6y9pL0nDimgzbjtJX5a0ozI/uJycc7dImiLpGpf5LWNorHqEpG9L2inb11M2VGRvKw/M0/R5kp5xzv0t1XdQ2xg7YuykxNhRWcdO3EHKJN6ClZJQtpG03Dn3xYY/MLO52Q6vMbODYsfOdM7Ncc41SeonqYOkq5xz65xzT0qaJen4Iq59lXNupXNuiaSnsm1Kmb/ICc65pc65DyVdmfJ7a5J0mXNurXNuTco2JGmic+6tbF8ejvVTzrlOzrnZzZ1kZt0lfV/Sf5dw7VrG2GkZY6d5jJ2WpRo7cWZ2mKT/UpHjqJSE8oGkbeLP+pxz+zvnOmXr4m0vjZW7Slqa/SFv8IYyz+sK9U6s/KkyAyVqO9FuGu875z5LeW5crn62ZIKkMc65jwL0oRYxdlrG2GkeY6dlaceOJMnMBkj6H0nDnHOLizm3lITyrKS1ko4q4Nj4lsZvSepuZvFr95C0LFteLWmLWN12RfTpbUndE+2mkdyC2euTmSX7FHrL5m9KGpddabFhcDxrZicEvk61MHZyH18qxs6/MHaKZGb9JT0k6VTn3BPFnp86oTjnVkr6uaTJZjbMzBrMrI2Z9ZO0ZZ5T5yuTNS8ws03N7BBJQyXdk61/QdJ3zGyL7HLH04ro1nRJ55jZDma2laSfFnFuPi9K6mNm/cxsM0k/S9S/K6lXoGtJ0lcl7aHMrWq/7J8NlfRAwGtUDWPHw9gpAmPHE3TsmNl/SnpU0tnOuYfTtFHSsmHn3DWSRkm6QJlv7l1Jv1JmpcLcHOesU+YHebgyqyImSzrZOfdy9pDrlVm58K6k3yoz8VSoWyU9pswP4i+S7i/uO2pe9rZvjKQ/KrPKI/kM8jZJvbPPcR8spM3suvOv57jee865dzZ8Zf94eYnPVWsKYyfC2CkSYycSdOxIOl9SZ0m3xT4bU9Sk/IZlbwAAlKSut14BAFQOCQUAEAQJBQAQBAkFABAECQUAEERRO1qaGUvCapBzrta37Gbc1KblzrnO1e5EPoydmtXs2OEOBdh4pd0iBGh27JBQAABBkFAAAEGQUAAAQZBQAABBkFAAAEGQ
UAAAQZBQAABBkFAAAEGQUAAAQZBQAABBkFAAAEGQUAAAQRS12zCAwk2fPt2LBwwYEJV79OhR6e4AZccdCgAgCBIKACCIij3yampqaracdOyxx3rxfffdV7Y+ASF1797di4cPH+7Fo0aNqmR30IqMHDnSi7fccksvXrp0aVSeOXNmRfqUBncoAIAgSCgAgCBIKACAIGpuDgVorYYNG5a3/tlnn61QT9AaxOdNbrnlFq+uTRv/d/0nnngiKjOHAgCoeyQUAEAQfFI+q0+fPlG5pUcXkydPjsrvv/9+2fqE8kgu7+3WrZsXz5s3L1W71113Xd76tO2i9dh0002j8hFHHOHVHXzwwV4c/4hE8hFX0qJFiwL0rvy4QwEABEFCAQAEQUIBAASx0c6hTJs2zYu7du0alffdd9+85+6zzz5ReciQIWE7hrJL/uz3228/LzazgttKzsegviW3RBk0aJAXX3zxxVG5f//+Xl1yXDnnCr5ux44dCz62mrhDAQAEQUIBAARBQgEABLHRzqE0NjZ6cTHbwRx22GGhu4MyO++886Jycs4kvjV4sfJ9ZomtVupDhw4dovLgwYO9uqlTpxbcTilzKEceeWRUvvHGG726s88+u+B2yo07FABAECQUAEAQFXvk1dLWAhskl3QeffTRXjxr1qyK9qclO++8sxf/85//DNIuSpNczptvW5TRo0envk7y8Vncm2++mbpd1I7evXtH5WIecYW01VZbReXjjz/eq7v77ruj8vz58yvWp+ZwhwIACIKEAgAIgoQCAAii5t/YeP/993txu3btgvTn0Ucf9eKBAwemaifZv759+6buE8KZM2dOzrrkMuHp06envs4OO+yQs45lw/Vh3bp1Ufm9997z6uLb1UvSp59+mrOdSy65xIvzLRv+xS9+4cXxraHi8ymStP322+dsp9K4QwEABEFCAQAEQUIBAARRsTmU+GtzzzzzzEpdNqdXX33Vi9POoaA2DBgwwIvzbStfyudOkvJ9DmXGjBnBroPqeeGFF6Jy/DMpktSlSxcv/sc//hHkmhdddFGQdiqNOxQAQBAkFABAEBV75BVfMvfZZ595dT/+8Y8r1Y3ID37wAy8uZikzak+PHj0KPraUZcLJR2v5lLKLMWrTihUr8sYbO+5QAABBkFAAAEGQUAAAQVRsDuXjjz+OymPHjvXqNttss6jc0pLimTNnRuWjjjoqUO/SS25fP27cuKj8k5/8pNLdQZkNHz489bnxpczJecPzzz8/dbto/W6//XYv3nXXXXMeu2DBAi9+8MEHy9GlVLhDAQAEQUIBAARBQgEABFGxOZS4+HyK5G9XsHr1aq+uoaHBi4cOHRqVH3nkEa+umDmLtm3bFnxsPptvvrkX59vOHOVTzFbxyc+SzJs3r+Bz8221Usy5o0aN8uomTJgQlfn8ysbhN7/5TVQ+6aSTvLrk1vbx10mfdtpp5e1YCbhDAQAEQUIBAARRlUdeSZMmTYrKyTegXXrppV7csWPHqJzcIfivf/1rwddcv369F6fdeiX5+K6YxycIJ/mYKBnHl+wmt1454IAD8p4bl++RV75rStL48eNznov6sO+++0blXXbZxau78sorvXjbbbeNymaWt934/zNf+9rXvLpkXKg777wz1Xn5cIcCAAiChAIACIKEAgAIoibmUOLiyyclae3atV584403VrA3LUv2b9myZVXqCeKS8yJLliyJysm5jXidJN17771RuZjlyMl258yZk7P+uuuu8+pYKlw9Xbt29eL4PEjSiSee6MXJNzjG50U6derk1SXnSZJLg/OJXye+3LgUzKEAAGoWCQUAEAQJBQAQhBXzHM/MCj+4TBobG3PWDRs2rOBj27Txc2naz6Ek24l/xuGEE05I1WaxnHP5F7FXWS2Mm/j8xbXXXuvVlbIlfTHi8yY1sl398865vardiXzKNXYOPPDAqHz11Vd7dcW85rkYyTmU+LxZfPupcnr88cejcomfi2p27HCH
AgAIgoQCAAii1T3yyqdz585evPXWW6dq56yzzvLifG+R5JFXy2p93CSX+8a3V0kuIU3uEpyv7vrrrw/Qu7Kq60de8eXAt956q1d36KGHRuV27dqlvcS/iW/p9Pzzz3t1zzzzjBffdtttUXnx4sXB+lAhPPICAJQPCQUAEAQJBQAQRM1tvVKK999/P29cqA8//NCLk/Mkcck3P44YMSIqL1++3Ku75JJLvDi59T2qo6Wt7wvVCuZMNip/+MMfovKuu+6aup3PPvssKr/++uteXfLf8NixY6Py7373u9TXbK24QwEABEFCAQAEUVePvEJJLqUu5lP08WOTy43jt86StGDBgqgc3+EWQOnib09taGjIedxLL73kxVOmTPHi+KPr3//+94F6V5+4QwEABEFCAQAEQUIBAARRV1uvhHLwwQd78cSJE714t912i8qhdi0uZfsHtl4pn/gycEmaNm1azmOTu8m2AnW99QrKiq1XAADlQ0IBAARBQgEABMEcSgGSb3CbOnVqVE5ufc4cyr+rp3GT/PcS36alR48ele5OqZhDQVrMoQAAyoeEAgAIgq1XChDfwkGSvvnNb0bl3r17e3X3339/VJ48ebJXd/PNN5ehd6ikVrg0GKgY7lAAAEGQUAAAQZBQAABBMIeSwmuvvdZsWSpt+S8AtGbcoQAAgiChAACCIKEAAIIgoQAAgiChAACCIKEAAIIgoQAAgiChAACCIKEAAIIgoQAAgih265Xlkt4oR0eQ2o7V7kABGDe1ibGDtJodO0W9AhgAgFx45AUACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACKKuE4qZ9TQzZ2bFbtMf4tqvm9nASl8XYTB2kNbGPHZKTihmdpyZzTez1Wb2XrZ8lplZiA6Wi5mtin01mdmaWHxikW3dYWZjA/ZtezN7yMzeyg7MnqHariWMnfBjJ9vm2Wb2f2b2sZk9Z2YHhmy/FjB2anPslJRQzOx8STdIGidpO0nbSjpT0gGS2uU4p20p1wzFOddhw5ekJZKGxv5syobjqvFbhqQmSY9KaqzCtSuCsVMeZravpKskDZPUUdJtkh6olb+7EBg75RFk7DjnUn1lL7haUmMLx90h6SZJj2SPHyhpN0lPS1opaZGkI2PHPy3pe7H4FEmzY7FTZvC8kj1/kv71orC2ksYr85a31yT9MHv8Ji308XVJA7PlQyS9KelCSe9IuivZh1g/dpZ0hqTPJa2TtErSw7E2R0v6m6SPJE2TtFmRf8ebZK/TM+3PqRa/GDvlGzuSjpX051i8ZfZ621f7587Yqf+xU8odyn6S2kuaWcCxJ0i6QlKDpPmSHpb0uKQuks6WNMXMdini2kdI2ltSX0kjJH0r++enZ+v6S9pLmUybxnaSvqzMay7PyHegc+4WSVMkXeMyv2UMjVWPkPRtSTtl+3rKhgozW1mPjyIKxNhR2cbO7yW1NbN9s79ZnirpBWX+k6oHjB3V7tgpJaFsI2m5c+6LWGfnZju8xswOih070zk3xznXJKmfpA6SrnLOrXPOPSlplqTji7j2Vc65lc65JZKeyrYpZf4iJzjnljrnPpR0ZcrvrUnSZc65tc65NSnbkKSJzrm3sn15ONZPOec6Oedml9B2a8bYaVnasfOJpPskzZa0VtJlks5w2V856wBjp2VVGzulJJQPJG0Tf9bnnNvfOdcpWxdve2ms3FXS0uwPeYM3JHUr4trxjPmpMgMlajvRbhrvO+c+S3luXK5+buwYOy1LO3ZOkzRSUh9l5hO+K2mWmXUN0KdawNhpWdXGTikJ5VllsthRBRwbz3BvSepuZvFr95C0LFteLWmLWN12RfTpbUndE+2mkczIXp/MLNmnevntr1IYO7mPL1U/SbOcc4udc03OuUeV+d72D3ydamHs5D6+VP1U4thJnVCccysl/VzSZDMbZmYNZtbGzPopM5mT
y3xlsuYFZrapmR0iaaike7L1L0j6jpltYWY7K5M1CzVd0jlmtoOZbSXpp0Wcm8+LkvqYWT8z20zSzxL170rqFehakqTsddpnw/bZuC4wdjyhx84CSUPMrJdlHCbpq5IWBrxG1TB2PDU3dkpaNuycu0bSKEkXKPPNvSvpV8qsVJib45x1yvwgD1dmVcRkSSc7517OHnK9MisX3pX0W2Umngp1q6THlPlB/EXS/cV9R81zzi2WNEbSH5VZ5ZF8BnmbpN7Z57gPFtJmdt351/McskaZ1RuS9HI2rhuMnUjosXOnMv9JPi3pY0kTJX0/9nfU6jF2IjU3djYsewMAoCR1vfUKAKBySCgAgCBIKACAIEgoAIAgSCgAgCCK2tHSzFgSVoOcc7W+ZTfjpjYtd851rnYn8mHs1Kxmxw53KMDGK+0WIUCzY4eEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACKKo3YaRMWjQoKi89957e3VXXHFFpbsDADWBOxQAQBAkFABAEDzySqGxsTEqv/TSS1XsCULYc889vfiMM85otixJL7/8shc/9NBDUXnSpEle3ZIlS0J1Ea1E165do/Lll1/u1Z166qle7Ny/3h22aNEir+6cc87x4qeeeipUF8uKOxQAQBAkFABAECQUAEAQFn+O1+LBZoUfXEe23HJLL/773/8elRcuXOjVDR48uCJ9inPOWcUvWoRaGzcDBw704vHjx3vx7rvvnqrdlStXevHo0aO9+M4774zK69evT3WNwJ53zu1V7U7kU2tjpyXTp0+PysOGDUvdzrp167z4pptuisrPPfecV/fpp59G5QceeCD1NYvU7NjhDgUAEAQJBQAQBMuGC3DMMcd4cbdu3aLygw8+WOHeII0tttgiKt93331eXYcOHYJco1OnTl7861//2ovjj05/+ctfBrkmqmvAgAFePGTIkCDttmvXzovPPffcnMc+8cQTUbmCj7yaxR0KACAIEgoAIAgSCgAgCOZQCjBq1CgvNvvXKt3k83jUpvjPrJg5k/iSTElasWKFF8fn01py5ZVXRuX40nPJfw6O1uOCCy7w4s033zznsWvWrPHiZcuWReXu3bt7de3bty+4D5988knBx5YbdygAgCBIKACAIEgoAIAgmENpxsiRI724X79+XhzfavrPf/5zJbqEKhk7dqwXJ5+RX3rppQW3Ff8szEUXXeTVMYfSOjU0NBR87NVXX+3FY8aMico/+tGPvLqJEyfmbGft2rVenNw6qJq4QwEABEFCAQAEwSOvrJ49e0bl+PJOSfr888+9+PTTT4/KyaWAqE3xn1N8R1hJGjFiRM7zkstCkztPp/XFF18EaQetxwcffJCzbo899sh7bnxX+AkTJnh1c+fOLalfIXGHAgAIgoQCAAiChAIACII5lKyLL744Knfu3Nmru/baa7143rx5FekTwmlqaorKs2bN8uryzaEkt6QvxYsvvhiVk0tIsfE56aSTovLw4cPzHvvUU09F5eSS81rCHQoAIAgSCgAgCBIKACCIjXYO5eSTT/bi0047LSqvW7fOq7v++usr0idUxpw5c7x4+fLlXrzNNtukajf5KoMbbrjBi+Nzb+vXr091DbRe8TkTyd+K50tf+lLec1vLFk/coQAAgiChAACC2GgeeXXp0sWLzznnHC+Ob21w7rnnenVvv/12+TqGioi/sfHwww/36kItDX7zzTe9eP78+V7MY676k3xcms8+++xT8LGLFy/24ssuu6zgc6uJOxQAQBAkFABAECQUAEAQFp87aPFgs8IPrjGjR4/24uTWF3fffXdUPvXUU726Wn/27Zyzlo+qnloYN8ccc0xUnjFjRkWuOXjwYC9+7LHHKnLdIjzvnNur2p3IpxbGTj5bb721F7/66qtRuaWlwHHJjyoceuihXlxLW9RnNTt2uEMBAARBQgEABEFCAQAEUdefQ9l+
++2j8tixY726VatWefFNN90UlWt9zgTFGzp0aMHHXnjhhVH5mWee8eqSrwSOz80kJbckr8E5FJQo+VrfJ598MiofffTRBbdz6623enENzpkUhDsUAEAQJBQAQBB1/cjrvPPOi8rt27f36q655hov5i2M9aV3795e3NjYmPPYhQsXenH88efq1au9umnTpnlx/LFGfHsXSfrud7/rxVOmTInK8TfwAcm3iLZW3KEAAIIgoQAAgiChAACCqKs5lL59+3pxfBv65Nbi48ePr0ifUB277LKLF3fo0CHnsck3cibnTeLuvfdeL44vGz722GO9uk033dSL49uXM4dSHzbZxP8vtHv37qnaOeigg7y4tS4x5w4FABAECQUAEAQJBQAQRF3NoYwcOdKL4883p06d6tV9/PHHFekTquO4444r+NhXXnkl9XWuuOKKqJycQ0mK1ydfn4DWqVevXl685557VqkntYE7FABAECQUAEAQrfqNjfHdhCVp0aJFXrxmzZqovN9++3l1S5YsKV/HKow3Nv67ESNGeHHykWdcfEsUSTr55JMLvk6XLl2i8ttvv5332GXLlkXlHj16FHyNMuKNjSVKLu897LDDUrUTf9OjJH3lK19J3acK4Y2NAIDyIaEAAIIgoQAAgmjVy4aHDBnixZ06dfLiu+66KyrX05wJWjZ79mwvXrFiRVTeaqutvLpDDjnEizt27BiVP/roo/CdQ91o0ybM7+Q9e/b04r328qcnnnvuuSDXKTfuUAAAQZBQAABBtLpHXvFPv5955pleXXKX2BtuuKEifULteeutt7x4wYIFUXnQoEFeXbdu3bz4T3/6U1S++eab815n//33T9tFINK2bVsvTn7inkdeAICNCgkFABAECQUAEESrm0OJLw3u37+/Vzdnzhwvfu211yrRJbQCl19+eVROznsk3+a4++67R+VJkyYF60NyKTNav/j2TiH16dOnLO2WG3coAIAgSCgAgCBIKACAIFrdHMqqVaui8owZM7w6s5rexR1VNHfu3Kjc2Njo1Y0bN86L+/btG+Sa8e3qJWnMmDFB2kXtOOuss7w4Pvex0047FdxOU1OTFz/++OOldaxKuEMBAARBQgEABNGq39iIDN7YWJqGhgYv3nHHHaPy6aef7tUl3wQZf2NjcruXww8/3IsXLlxYUj/LgDc2BhbfJfiee+7x6nr16pXzvPh2P5L0jW98I2zHwuONjQCA8iGhAACCIKEAAIJgDqUOMIeClJhDQVrMoQAAyoeEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACKLYNzYul/RGOTqC1HZs+ZCqY9zUJsYO0mp27BS1lxcAALnwyAsAEAQJBQAQBAkFABAECQUAEAQJBQAQBAkFABAECQUAEAQJBQAQBAkFABDE/wNQMbijkRDDOQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 6 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Show the first six test images with their ground-truth labels.\n",
    "fig = plt.figure()\n",
    "for i in range(6):\n",
    "    plt.subplot(2, 3, i + 1)\n",
    "    plt.tight_layout()\n",
    "    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n",
    "    plt.title(\"Ground Truth: {}\".format(example_targets[i]))\n",
    "    plt.xticks([])  # hide axis ticks for a cleaner grid\n",
    "    plt.yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    \"\"\"CNN for 28x28 single-channel images (MNIST): four Conv-BN-ReLU\n",
    "    blocks (max-pooling after blocks 2 and 4) followed by a three-layer\n",
    "    fully connected classifier with dropout and batch norm.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "        self.conv1 = nn.Sequential(  # input: (1, 28, 28)\n",
    "            nn.Conv2d(\n",
    "                in_channels=1,   # input channels (would be 3 for RGB)\n",
    "                out_channels=32, # number of convolution kernels\n",
    "                kernel_size=3,   # kernel size\n",
    "                stride=1,        # kernel stride\n",
    "                padding=1,       # pad so spatial size is unchanged\n",
    "            ),  # -> (32, 28, 28)\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(32, 32, 3, 1, 1),  # -> (32, 28, 28)\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2),  # pooling -> (32, 14, 14)\n",
    "        )\n",
    "        self.conv3 = nn.Sequential(  # (32, 14, 14)\n",
    "            nn.Conv2d(32, 64, 3, 1, 1),  # -> (64, 14, 14)\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "        self.conv4 = nn.Sequential(\n",
    "            nn.Conv2d(64, 64, 3, 1, 1),  # -> (64, 14, 14)\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(2),  # -> (64, 7, 7)\n",
    "        )\n",
    "        self.out = nn.Sequential(\n",
    "            nn.Dropout(p=0.5),  # regularization against overfitting\n",
    "            nn.Linear(64 * 7 * 7, 512),\n",
    "            nn.BatchNorm1d(512),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.BatchNorm1d(512),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.Linear(512, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-class log-probabilities of shape (batch_size, 10).\"\"\"\n",
    "        x = self.conv1(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.conv3(x)\n",
    "        x = self.conv4(x)\n",
    "        x = x.view(x.size(0), -1)  # flatten to (batch_size, 64*7*7)\n",
    "        output = self.out(x)\n",
    "        # explicit dim=1 (class axis) — implicit dim is deprecated and warned\n",
    "        return F.log_softmax(output, dim=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the model and an SGD-with-momentum optimizer over its parameters.\n",
    "network = CNN()\n",
    "optimizer = optim.SGD(network.parameters(), lr=learning_rate,\n",
    "                      momentum=momentum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv1): Sequential(\n",
       "    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv2): Sequential(\n",
       "    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv3): Sequential(\n",
       "    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv4): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (out): Sequential(\n",
       "    (0): Dropout(p=0.5, inplace=False)\n",
       "    (1): Linear(in_features=3136, out_features=512, bias=True)\n",
       "    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): ReLU(inplace=True)\n",
       "    (4): Dropout(p=0.5, inplace=False)\n",
       "    (5): Linear(in_features=512, out_features=512, bias=True)\n",
       "    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (7): ReLU(inplace=True)\n",
       "    (8): Dropout(p=0.5, inplace=False)\n",
       "    (9): Linear(in_features=512, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 网络结构\n",
    "- 先是四个卷积层，每个卷积层之后跟一个BN层和ReLU激活函数，第二、四个卷积层之后使用了MaxPooling\n",
    "- 再是三个全连接层，前两层之后各跟一个BN层和ReLU激活函数，并在每个全连接层之前使用Dropout抑制过拟合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of paramerters in networks is 1941354  \n"
     ]
    }
   ],
   "source": [
    "# Total trainable parameter count (weights, biases, and BatchNorm params).\n",
    "# Original message had typos: \"paramerters in networks\".\n",
    "print(\"Total number of parameters in the network is {}  \".format(sum(x.numel() for x in network.parameters())))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bookkeeping for loss curves across training.\n",
    "train_losses = []\n",
    "train_counter = []\n",
    "test_losses = []\n",
    "# example-count x-positions, one per epoch boundary — presumably for\n",
    "# plotting test loss before training and after each epoch; verify usage\n",
    "test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-14-476dee2bb741>:51: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return F.log_softmax(output)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.333150\n",
      "Train Epoch: 1 [640/60000 (1%)]\tLoss: 2.037497\n",
      "Train Epoch: 1 [1280/60000 (2%)]\tLoss: 1.840573\n",
      "Train Epoch: 1 [1920/60000 (3%)]\tLoss: 1.425134\n",
      "Train Epoch: 1 [2560/60000 (4%)]\tLoss: 1.287103\n",
      "Train Epoch: 1 [3200/60000 (5%)]\tLoss: 1.171738\n",
      "Train Epoch: 1 [3840/60000 (6%)]\tLoss: 0.942139\n",
      "Train Epoch: 1 [4480/60000 (7%)]\tLoss: 0.725613\n",
      "Train Epoch: 1 [5120/60000 (9%)]\tLoss: 0.507201\n",
      "Train Epoch: 1 [5760/60000 (10%)]\tLoss: 0.497014\n",
      "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.523501\n",
      "Train Epoch: 1 [7040/60000 (12%)]\tLoss: 0.489296\n",
      "Train Epoch: 1 [7680/60000 (13%)]\tLoss: 0.546101\n",
      "Train Epoch: 1 [8320/60000 (14%)]\tLoss: 0.394666\n",
      "Train Epoch: 1 [8960/60000 (15%)]\tLoss: 0.445805\n",
      "Train Epoch: 1 [9600/60000 (16%)]\tLoss: 0.338042\n",
      "Train Epoch: 1 [10240/60000 (17%)]\tLoss: 0.302797\n",
      "Train Epoch: 1 [10880/60000 (18%)]\tLoss: 0.369880\n",
      "Train Epoch: 1 [11520/60000 (19%)]\tLoss: 0.427279\n",
      "Train Epoch: 1 [12160/60000 (20%)]\tLoss: 0.274364\n",
      "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.293333\n",
      "Train Epoch: 1 [13440/60000 (22%)]\tLoss: 0.395420\n",
      "Train Epoch: 1 [14080/60000 (23%)]\tLoss: 0.221817\n",
      "Train Epoch: 1 [14720/60000 (25%)]\tLoss: 0.272894\n",
      "Train Epoch: 1 [15360/60000 (26%)]\tLoss: 0.226839\n",
      "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.338889\n",
      "Train Epoch: 1 [16640/60000 (28%)]\tLoss: 0.202905\n",
      "Train Epoch: 1 [17280/60000 (29%)]\tLoss: 0.298363\n",
      "Train Epoch: 1 [17920/60000 (30%)]\tLoss: 0.213757\n",
      "Train Epoch: 1 [18560/60000 (31%)]\tLoss: 0.249515\n",
      "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.183747\n",
      "Train Epoch: 1 [19840/60000 (33%)]\tLoss: 0.363137\n",
      "Train Epoch: 1 [20480/60000 (34%)]\tLoss: 0.315455\n",
      "Train Epoch: 1 [21120/60000 (35%)]\tLoss: 0.239646\n",
      "Train Epoch: 1 [21760/60000 (36%)]\tLoss: 0.218483\n",
      "Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.309400\n",
      "Train Epoch: 1 [23040/60000 (38%)]\tLoss: 0.166086\n",
      "Train Epoch: 1 [23680/60000 (39%)]\tLoss: 0.318413\n",
      "Train Epoch: 1 [24320/60000 (41%)]\tLoss: 0.111112\n",
      "Train Epoch: 1 [24960/60000 (42%)]\tLoss: 0.086917\n",
      "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.107234\n",
      "Train Epoch: 1 [26240/60000 (44%)]\tLoss: 0.229634\n",
      "Train Epoch: 1 [26880/60000 (45%)]\tLoss: 0.133027\n",
      "Train Epoch: 1 [27520/60000 (46%)]\tLoss: 0.198194\n",
      "Train Epoch: 1 [28160/60000 (47%)]\tLoss: 0.220334\n",
      "Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.115782\n",
      "Train Epoch: 1 [29440/60000 (49%)]\tLoss: 0.211294\n",
      "Train Epoch: 1 [30080/60000 (50%)]\tLoss: 0.252153\n",
      "Train Epoch: 1 [30720/60000 (51%)]\tLoss: 0.159103\n",
      "Train Epoch: 1 [31360/60000 (52%)]\tLoss: 0.223440\n",
      "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.205165\n",
      "Train Epoch: 1 [32640/60000 (54%)]\tLoss: 0.247041\n",
      "Train Epoch: 1 [33280/60000 (55%)]\tLoss: 0.128686\n",
      "Train Epoch: 1 [33920/60000 (57%)]\tLoss: 0.101456\n",
      "Train Epoch: 1 [34560/60000 (58%)]\tLoss: 0.091019\n",
      "Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.373986\n",
      "Train Epoch: 1 [35840/60000 (60%)]\tLoss: 0.139531\n",
      "Train Epoch: 1 [36480/60000 (61%)]\tLoss: 0.174432\n",
      "Train Epoch: 1 [37120/60000 (62%)]\tLoss: 0.124765\n",
      "Train Epoch: 1 [37760/60000 (63%)]\tLoss: 0.190884\n",
      "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.260908\n",
      "Train Epoch: 1 [39040/60000 (65%)]\tLoss: 0.161259\n",
      "Train Epoch: 1 [39680/60000 (66%)]\tLoss: 0.167599\n",
      "Train Epoch: 1 [40320/60000 (67%)]\tLoss: 0.122063\n",
      "Train Epoch: 1 [40960/60000 (68%)]\tLoss: 0.142709\n",
      "Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.064887\n",
      "Train Epoch: 1 [42240/60000 (70%)]\tLoss: 0.119196\n",
      "Train Epoch: 1 [42880/60000 (71%)]\tLoss: 0.195196\n",
      "Train Epoch: 1 [43520/60000 (72%)]\tLoss: 0.042703\n",
      "Train Epoch: 1 [44160/60000 (74%)]\tLoss: 0.101309\n",
      "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.102936\n",
      "Train Epoch: 1 [45440/60000 (76%)]\tLoss: 0.117652\n",
      "Train Epoch: 1 [46080/60000 (77%)]\tLoss: 0.087879\n",
      "Train Epoch: 1 [46720/60000 (78%)]\tLoss: 0.039908\n",
      "Train Epoch: 1 [47360/60000 (79%)]\tLoss: 0.173407\n",
      "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.097155\n",
      "Train Epoch: 1 [48640/60000 (81%)]\tLoss: 0.031641\n",
      "Train Epoch: 1 [49280/60000 (82%)]\tLoss: 0.103646\n",
      "Train Epoch: 1 [49920/60000 (83%)]\tLoss: 0.119902\n",
      "Train Epoch: 1 [50560/60000 (84%)]\tLoss: 0.109641\n",
      "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.135710\n",
      "Train Epoch: 1 [51840/60000 (86%)]\tLoss: 0.186560\n",
      "Train Epoch: 1 [52480/60000 (87%)]\tLoss: 0.045716\n",
      "Train Epoch: 1 [53120/60000 (88%)]\tLoss: 0.052023\n",
      "Train Epoch: 1 [53760/60000 (90%)]\tLoss: 0.173905\n",
      "Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.065863\n",
      "Train Epoch: 1 [55040/60000 (92%)]\tLoss: 0.231615\n",
      "Train Epoch: 1 [55680/60000 (93%)]\tLoss: 0.095900\n",
      "Train Epoch: 1 [56320/60000 (94%)]\tLoss: 0.048545\n",
      "Train Epoch: 1 [56960/60000 (95%)]\tLoss: 0.172763\n",
      "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.062961\n",
      "Train Epoch: 1 [58240/60000 (97%)]\tLoss: 0.038528\n",
      "Train Epoch: 1 [58880/60000 (98%)]\tLoss: 0.030154\n",
      "Train Epoch: 1 [59520/60000 (99%)]\tLoss: 0.087513\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def train(epoch):\n",
    "    \"\"\"Run one training epoch over train_loader.\n",
    "\n",
    "    Every `log_interval` batches: print progress, record the loss into\n",
    "    train_losses/train_counter, and checkpoint model + optimizer state.\n",
    "    \"\"\"\n",
    "    network.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        output = network(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % log_interval == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                       100. * batch_idx / len(train_loader), loss.item()))\n",
    "            train_losses.append(loss.item())\n",
    "            # use the configured batch size, not a hard-coded 64, so the\n",
    "            # counter stays correct if batch_size_train is changed\n",
    "            train_counter.append(\n",
    "                (batch_idx * batch_size_train) + ((epoch - 1) * len(train_loader.dataset)))\n",
    "            torch.save(network.state_dict(), './model.pth')\n",
    "            torch.save(optimizer.state_dict(), './optimizer.pth')\n",
    "\n",
    "# train(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-14-476dee2bb741>:51: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return F.log_softmax(output)\n",
      "/home/sychen/venv_envirment/tf_venv/lib/python3.8/site-packages/torch/nn/_reduction.py:44: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Avg. loss: 0.0512, Accuracy: 9834/10000 (98%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def test():\n",
    "    \"\"\"Evaluate the network on test_loader: append the average NLL loss\n",
    "    to test_losses and print loss and accuracy.\"\"\"\n",
    "    network.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            output = network(data)\n",
    "            # reduction='sum' replaces the deprecated size_average=False\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
    "            pred = output.data.max(1, keepdim=True)[1]\n",
    "            correct += pred.eq(target.data.view_as(pred)).sum()\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    test_losses.append(test_loss)\n",
    "    print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-14-476dee2bb741>:51: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return F.log_softmax(output)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Avg. loss: 0.0512, Accuracy: 9834/10000 (98%)\n",
      "\n",
      "Train Epoch: 1 [0/60000 (0%)]\tLoss: 0.143411\n",
      "Train Epoch: 1 [640/60000 (1%)]\tLoss: 0.127005\n",
      "Train Epoch: 1 [1280/60000 (2%)]\tLoss: 0.147649\n",
      "Train Epoch: 1 [1920/60000 (3%)]\tLoss: 0.071453\n",
      "Train Epoch: 1 [2560/60000 (4%)]\tLoss: 0.212009\n",
      "Train Epoch: 1 [3200/60000 (5%)]\tLoss: 0.070162\n",
      "Train Epoch: 1 [3840/60000 (6%)]\tLoss: 0.229099\n",
      "Train Epoch: 1 [4480/60000 (7%)]\tLoss: 0.069427\n",
      "Train Epoch: 1 [5120/60000 (9%)]\tLoss: 0.120209\n",
      "Train Epoch: 1 [5760/60000 (10%)]\tLoss: 0.106318\n",
      "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.057732\n",
      "Train Epoch: 1 [7040/60000 (12%)]\tLoss: 0.146903\n",
      "Train Epoch: 1 [7680/60000 (13%)]\tLoss: 0.094874\n",
      "Train Epoch: 1 [8320/60000 (14%)]\tLoss: 0.152634\n",
      "Train Epoch: 1 [8960/60000 (15%)]\tLoss: 0.086502\n",
      "Train Epoch: 1 [9600/60000 (16%)]\tLoss: 0.026266\n",
      "Train Epoch: 1 [10240/60000 (17%)]\tLoss: 0.204389\n",
      "Train Epoch: 1 [10880/60000 (18%)]\tLoss: 0.145320\n",
      "Train Epoch: 1 [11520/60000 (19%)]\tLoss: 0.025304\n",
      "Train Epoch: 1 [12160/60000 (20%)]\tLoss: 0.059063\n",
      "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.097810\n",
      "Train Epoch: 1 [13440/60000 (22%)]\tLoss: 0.102994\n",
      "Train Epoch: 1 [14080/60000 (23%)]\tLoss: 0.028142\n",
      "Train Epoch: 1 [14720/60000 (25%)]\tLoss: 0.223580\n",
      "Train Epoch: 1 [15360/60000 (26%)]\tLoss: 0.041273\n",
      "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.099279\n",
      "Train Epoch: 1 [16640/60000 (28%)]\tLoss: 0.167446\n",
      "Train Epoch: 1 [17280/60000 (29%)]\tLoss: 0.101908\n",
      "Train Epoch: 1 [17920/60000 (30%)]\tLoss: 0.129389\n",
      "Train Epoch: 1 [18560/60000 (31%)]\tLoss: 0.106327\n",
      "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.063659\n",
      "Train Epoch: 1 [19840/60000 (33%)]\tLoss: 0.098349\n",
      "Train Epoch: 1 [20480/60000 (34%)]\tLoss: 0.050963\n",
      "Train Epoch: 1 [21120/60000 (35%)]\tLoss: 0.036312\n",
      "Train Epoch: 1 [21760/60000 (36%)]\tLoss: 0.018895\n",
      "Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.188297\n",
      "Train Epoch: 1 [23040/60000 (38%)]\tLoss: 0.157826\n",
      "Train Epoch: 1 [23680/60000 (39%)]\tLoss: 0.119203\n",
      "Train Epoch: 1 [24320/60000 (41%)]\tLoss: 0.220768\n",
      "Train Epoch: 1 [24960/60000 (42%)]\tLoss: 0.080389\n",
      "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.115253\n",
      "Train Epoch: 1 [26240/60000 (44%)]\tLoss: 0.045889\n",
      "Train Epoch: 1 [26880/60000 (45%)]\tLoss: 0.082754\n",
      "Train Epoch: 1 [27520/60000 (46%)]\tLoss: 0.060592\n",
      "Train Epoch: 1 [28160/60000 (47%)]\tLoss: 0.037155\n",
      "Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.191909\n",
      "Train Epoch: 1 [29440/60000 (49%)]\tLoss: 0.067434\n",
      "Train Epoch: 1 [30080/60000 (50%)]\tLoss: 0.149938\n",
      "Train Epoch: 1 [30720/60000 (51%)]\tLoss: 0.103598\n",
      "Train Epoch: 1 [31360/60000 (52%)]\tLoss: 0.088378\n",
      "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.048688\n",
      "Train Epoch: 1 [32640/60000 (54%)]\tLoss: 0.125456\n",
      "Train Epoch: 1 [33280/60000 (55%)]\tLoss: 0.096177\n",
      "Train Epoch: 1 [33920/60000 (57%)]\tLoss: 0.040322\n",
      "Train Epoch: 1 [34560/60000 (58%)]\tLoss: 0.112986\n",
      "Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.138497\n",
      "Train Epoch: 1 [35840/60000 (60%)]\tLoss: 0.132050\n",
      "Train Epoch: 1 [36480/60000 (61%)]\tLoss: 0.067831\n",
      "Train Epoch: 1 [37120/60000 (62%)]\tLoss: 0.179142\n",
      "Train Epoch: 1 [37760/60000 (63%)]\tLoss: 0.074759\n",
      "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.033656\n",
      "Train Epoch: 1 [39040/60000 (65%)]\tLoss: 0.111598\n",
      "Train Epoch: 1 [39680/60000 (66%)]\tLoss: 0.112150\n",
      "Train Epoch: 1 [40320/60000 (67%)]\tLoss: 0.077940\n",
      "Train Epoch: 1 [40960/60000 (68%)]\tLoss: 0.059135\n",
      "Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.040010\n",
      "Train Epoch: 1 [42240/60000 (70%)]\tLoss: 0.019285\n",
      "Train Epoch: 1 [42880/60000 (71%)]\tLoss: 0.122847\n",
      "Train Epoch: 1 [43520/60000 (72%)]\tLoss: 0.020630\n",
      "Train Epoch: 1 [44160/60000 (74%)]\tLoss: 0.090358\n",
      "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.035513\n",
      "Train Epoch: 1 [45440/60000 (76%)]\tLoss: 0.138427\n",
      "Train Epoch: 1 [46080/60000 (77%)]\tLoss: 0.120700\n",
      "Train Epoch: 1 [46720/60000 (78%)]\tLoss: 0.060035\n",
      "Train Epoch: 1 [47360/60000 (79%)]\tLoss: 0.090507\n",
      "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.151025\n",
      "Train Epoch: 1 [48640/60000 (81%)]\tLoss: 0.061816\n",
      "Train Epoch: 1 [49280/60000 (82%)]\tLoss: 0.086713\n",
      "Train Epoch: 1 [49920/60000 (83%)]\tLoss: 0.018001\n",
      "Train Epoch: 1 [50560/60000 (84%)]\tLoss: 0.127845\n",
      "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.040950\n",
      "Train Epoch: 1 [51840/60000 (86%)]\tLoss: 0.113659\n",
      "Train Epoch: 1 [52480/60000 (87%)]\tLoss: 0.237763\n",
      "Train Epoch: 1 [53120/60000 (88%)]\tLoss: 0.059406\n",
      "Train Epoch: 1 [53760/60000 (90%)]\tLoss: 0.086740\n",
      "Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.088749\n",
      "Train Epoch: 1 [55040/60000 (92%)]\tLoss: 0.110734\n",
      "Train Epoch: 1 [55680/60000 (93%)]\tLoss: 0.164675\n",
      "Train Epoch: 1 [56320/60000 (94%)]\tLoss: 0.015013\n",
      "Train Epoch: 1 [56960/60000 (95%)]\tLoss: 0.035054\n",
      "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.061095\n",
      "Train Epoch: 1 [58240/60000 (97%)]\tLoss: 0.026013\n",
      "Train Epoch: 1 [58880/60000 (98%)]\tLoss: 0.052296\n",
      "Train Epoch: 1 [59520/60000 (99%)]\tLoss: 0.189326\n",
      "\n",
      "Test set: Avg. loss: 0.0319, Accuracy: 9889/10000 (99%)\n",
      "\n",
      "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.044756\n",
      "Train Epoch: 2 [640/60000 (1%)]\tLoss: 0.078974\n",
      "Train Epoch: 2 [1280/60000 (2%)]\tLoss: 0.097580\n",
      "Train Epoch: 2 [1920/60000 (3%)]\tLoss: 0.073131\n",
      "Train Epoch: 2 [2560/60000 (4%)]\tLoss: 0.059036\n",
      "Train Epoch: 2 [3200/60000 (5%)]\tLoss: 0.078063\n",
      "Train Epoch: 2 [3840/60000 (6%)]\tLoss: 0.078881\n",
      "Train Epoch: 2 [4480/60000 (7%)]\tLoss: 0.047245\n",
      "Train Epoch: 2 [5120/60000 (9%)]\tLoss: 0.056273\n",
      "Train Epoch: 2 [5760/60000 (10%)]\tLoss: 0.037997\n",
      "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.037513\n",
      "Train Epoch: 2 [7040/60000 (12%)]\tLoss: 0.130159\n",
      "Train Epoch: 2 [7680/60000 (13%)]\tLoss: 0.035135\n",
      "Train Epoch: 2 [8320/60000 (14%)]\tLoss: 0.170816\n",
      "Train Epoch: 2 [8960/60000 (15%)]\tLoss: 0.047632\n",
      "Train Epoch: 2 [9600/60000 (16%)]\tLoss: 0.032222\n",
      "Train Epoch: 2 [10240/60000 (17%)]\tLoss: 0.022743\n",
      "Train Epoch: 2 [10880/60000 (18%)]\tLoss: 0.070120\n",
      "Train Epoch: 2 [11520/60000 (19%)]\tLoss: 0.107791\n",
      "Train Epoch: 2 [12160/60000 (20%)]\tLoss: 0.014869\n",
      "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.063113\n",
      "Train Epoch: 2 [13440/60000 (22%)]\tLoss: 0.078367\n",
      "Train Epoch: 2 [14080/60000 (23%)]\tLoss: 0.061224\n",
      "Train Epoch: 2 [14720/60000 (25%)]\tLoss: 0.118822\n",
      "Train Epoch: 2 [15360/60000 (26%)]\tLoss: 0.053443\n",
      "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.088248\n",
      "Train Epoch: 2 [16640/60000 (28%)]\tLoss: 0.074043\n",
      "Train Epoch: 2 [17280/60000 (29%)]\tLoss: 0.031599\n",
      "Train Epoch: 2 [17920/60000 (30%)]\tLoss: 0.130468\n",
      "Train Epoch: 2 [18560/60000 (31%)]\tLoss: 0.028004\n",
      "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.053524\n",
      "Train Epoch: 2 [19840/60000 (33%)]\tLoss: 0.024214\n",
      "Train Epoch: 2 [20480/60000 (34%)]\tLoss: 0.081928\n",
      "Train Epoch: 2 [21120/60000 (35%)]\tLoss: 0.090345\n",
      "Train Epoch: 2 [21760/60000 (36%)]\tLoss: 0.101643\n",
      "Train Epoch: 2 [22400/60000 (37%)]\tLoss: 0.073638\n",
      "Train Epoch: 2 [23040/60000 (38%)]\tLoss: 0.141650\n",
      "Train Epoch: 2 [23680/60000 (39%)]\tLoss: 0.078445\n",
      "Train Epoch: 2 [24320/60000 (41%)]\tLoss: 0.075029\n",
      "Train Epoch: 2 [24960/60000 (42%)]\tLoss: 0.052604\n",
      "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.076562\n",
      "Train Epoch: 2 [26240/60000 (44%)]\tLoss: 0.122360\n",
      "Train Epoch: 2 [26880/60000 (45%)]\tLoss: 0.036844\n",
      "Train Epoch: 2 [27520/60000 (46%)]\tLoss: 0.038730\n",
      "Train Epoch: 2 [28160/60000 (47%)]\tLoss: 0.154715\n",
      "Train Epoch: 2 [28800/60000 (48%)]\tLoss: 0.099259\n",
      "Train Epoch: 2 [29440/60000 (49%)]\tLoss: 0.037651\n",
      "Train Epoch: 2 [30080/60000 (50%)]\tLoss: 0.046535\n",
      "Train Epoch: 2 [30720/60000 (51%)]\tLoss: 0.152425\n",
      "Train Epoch: 2 [31360/60000 (52%)]\tLoss: 0.087176\n",
      "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.102725\n",
      "Train Epoch: 2 [32640/60000 (54%)]\tLoss: 0.096574\n",
      "Train Epoch: 2 [33280/60000 (55%)]\tLoss: 0.009956\n",
      "Train Epoch: 2 [33920/60000 (57%)]\tLoss: 0.106885\n",
      "Train Epoch: 2 [34560/60000 (58%)]\tLoss: 0.071458\n",
      "Train Epoch: 2 [35200/60000 (59%)]\tLoss: 0.041248\n",
      "Train Epoch: 2 [35840/60000 (60%)]\tLoss: 0.082687\n",
      "Train Epoch: 2 [36480/60000 (61%)]\tLoss: 0.086758\n",
      "Train Epoch: 2 [37120/60000 (62%)]\tLoss: 0.054305\n",
      "Train Epoch: 2 [37760/60000 (63%)]\tLoss: 0.041540\n",
      "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.021576\n",
      "Train Epoch: 2 [39040/60000 (65%)]\tLoss: 0.159315\n",
      "Train Epoch: 2 [39680/60000 (66%)]\tLoss: 0.081624\n",
      "Train Epoch: 2 [40320/60000 (67%)]\tLoss: 0.039511\n",
      "Train Epoch: 2 [40960/60000 (68%)]\tLoss: 0.043043\n",
      "Train Epoch: 2 [41600/60000 (69%)]\tLoss: 0.103599\n",
      "Train Epoch: 2 [42240/60000 (70%)]\tLoss: 0.178088\n",
      "Train Epoch: 2 [42880/60000 (71%)]\tLoss: 0.128928\n",
      "Train Epoch: 2 [43520/60000 (72%)]\tLoss: 0.101569\n",
      "Train Epoch: 2 [44160/60000 (74%)]\tLoss: 0.060398\n",
      "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.047167\n",
      "Train Epoch: 2 [45440/60000 (76%)]\tLoss: 0.018755\n",
      "Train Epoch: 2 [46080/60000 (77%)]\tLoss: 0.016992\n",
      "Train Epoch: 2 [46720/60000 (78%)]\tLoss: 0.069840\n",
      "Train Epoch: 2 [47360/60000 (79%)]\tLoss: 0.088409\n",
      "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.054767\n",
      "Train Epoch: 2 [48640/60000 (81%)]\tLoss: 0.263176\n",
      "Train Epoch: 2 [49280/60000 (82%)]\tLoss: 0.085914\n",
      "Train Epoch: 2 [49920/60000 (83%)]\tLoss: 0.099500\n",
      "Train Epoch: 2 [50560/60000 (84%)]\tLoss: 0.198909\n",
      "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.025325\n",
      "Train Epoch: 2 [51840/60000 (86%)]\tLoss: 0.076238\n",
      "Train Epoch: 2 [52480/60000 (87%)]\tLoss: 0.170034\n",
      "Train Epoch: 2 [53120/60000 (88%)]\tLoss: 0.076846\n",
      "Train Epoch: 2 [53760/60000 (90%)]\tLoss: 0.065312\n",
      "Train Epoch: 2 [54400/60000 (91%)]\tLoss: 0.053269\n",
      "Train Epoch: 2 [55040/60000 (92%)]\tLoss: 0.117017\n",
      "Train Epoch: 2 [55680/60000 (93%)]\tLoss: 0.029505\n",
      "Train Epoch: 2 [56320/60000 (94%)]\tLoss: 0.017016\n",
      "Train Epoch: 2 [56960/60000 (95%)]\tLoss: 0.020341\n",
      "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.054202\n",
      "Train Epoch: 2 [58240/60000 (97%)]\tLoss: 0.019319\n",
      "Train Epoch: 2 [58880/60000 (98%)]\tLoss: 0.037220\n",
      "Train Epoch: 2 [59520/60000 (99%)]\tLoss: 0.047444\n",
      "\n",
      "Test set: Avg. loss: 0.0277, Accuracy: 9907/10000 (99%)\n",
      "\n",
      "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.050359\n",
      "Train Epoch: 3 [640/60000 (1%)]\tLoss: 0.109776\n",
      "Train Epoch: 3 [1280/60000 (2%)]\tLoss: 0.105801\n",
      "Train Epoch: 3 [1920/60000 (3%)]\tLoss: 0.028432\n",
      "Train Epoch: 3 [2560/60000 (4%)]\tLoss: 0.079812\n",
      "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.144449\n",
      "Train Epoch: 3 [3840/60000 (6%)]\tLoss: 0.032163\n",
      "Train Epoch: 3 [4480/60000 (7%)]\tLoss: 0.052444\n",
      "Train Epoch: 3 [5120/60000 (9%)]\tLoss: 0.034681\n",
      "Train Epoch: 3 [5760/60000 (10%)]\tLoss: 0.130200\n",
      "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.099211\n",
      "Train Epoch: 3 [7040/60000 (12%)]\tLoss: 0.069577\n",
      "Train Epoch: 3 [7680/60000 (13%)]\tLoss: 0.012800\n",
      "Train Epoch: 3 [8320/60000 (14%)]\tLoss: 0.101536\n",
      "Train Epoch: 3 [8960/60000 (15%)]\tLoss: 0.065700\n",
      "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.064023\n",
      "Train Epoch: 3 [10240/60000 (17%)]\tLoss: 0.112882\n",
      "Train Epoch: 3 [10880/60000 (18%)]\tLoss: 0.104735\n",
      "Train Epoch: 3 [11520/60000 (19%)]\tLoss: 0.099682\n",
      "Train Epoch: 3 [12160/60000 (20%)]\tLoss: 0.129080\n",
      "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.049498\n",
      "Train Epoch: 3 [13440/60000 (22%)]\tLoss: 0.014903\n",
      "Train Epoch: 3 [14080/60000 (23%)]\tLoss: 0.017465\n",
      "Train Epoch: 3 [14720/60000 (25%)]\tLoss: 0.047198\n",
      "Train Epoch: 3 [15360/60000 (26%)]\tLoss: 0.102332\n",
      "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.009866\n",
      "Train Epoch: 3 [16640/60000 (28%)]\tLoss: 0.038312\n",
      "Train Epoch: 3 [17280/60000 (29%)]\tLoss: 0.022938\n",
      "Train Epoch: 3 [17920/60000 (30%)]\tLoss: 0.034547\n",
      "Train Epoch: 3 [18560/60000 (31%)]\tLoss: 0.023827\n",
      "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.144600\n",
      "Train Epoch: 3 [19840/60000 (33%)]\tLoss: 0.086569\n",
      "Train Epoch: 3 [20480/60000 (34%)]\tLoss: 0.200349\n",
      "Train Epoch: 3 [21120/60000 (35%)]\tLoss: 0.076869\n",
      "Train Epoch: 3 [21760/60000 (36%)]\tLoss: 0.033328\n",
      "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.020793\n",
      "Train Epoch: 3 [23040/60000 (38%)]\tLoss: 0.011943\n",
      "Train Epoch: 3 [23680/60000 (39%)]\tLoss: 0.078258\n",
      "Train Epoch: 3 [24320/60000 (41%)]\tLoss: 0.135668\n",
      "Train Epoch: 3 [24960/60000 (42%)]\tLoss: 0.111256\n",
      "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.143342\n",
      "Train Epoch: 3 [26240/60000 (44%)]\tLoss: 0.096037\n",
      "Train Epoch: 3 [26880/60000 (45%)]\tLoss: 0.018426\n",
      "Train Epoch: 3 [27520/60000 (46%)]\tLoss: 0.018491\n",
      "Train Epoch: 3 [28160/60000 (47%)]\tLoss: 0.046770\n",
      "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.170885\n",
      "Train Epoch: 3 [29440/60000 (49%)]\tLoss: 0.030876\n",
      "Train Epoch: 3 [30080/60000 (50%)]\tLoss: 0.062226\n",
      "Train Epoch: 3 [30720/60000 (51%)]\tLoss: 0.018752\n",
      "Train Epoch: 3 [31360/60000 (52%)]\tLoss: 0.151010\n",
      "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.031047\n",
      "Train Epoch: 3 [32640/60000 (54%)]\tLoss: 0.068647\n",
      "Train Epoch: 3 [33280/60000 (55%)]\tLoss: 0.060802\n",
      "Train Epoch: 3 [33920/60000 (57%)]\tLoss: 0.062410\n",
      "Train Epoch: 3 [34560/60000 (58%)]\tLoss: 0.031893\n",
      "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.071355\n",
      "Train Epoch: 3 [35840/60000 (60%)]\tLoss: 0.046244\n",
      "Train Epoch: 3 [36480/60000 (61%)]\tLoss: 0.043906\n",
      "Train Epoch: 3 [37120/60000 (62%)]\tLoss: 0.043662\n",
      "Train Epoch: 3 [37760/60000 (63%)]\tLoss: 0.060881\n",
      "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.029288\n",
      "Train Epoch: 3 [39040/60000 (65%)]\tLoss: 0.116116\n",
      "Train Epoch: 3 [39680/60000 (66%)]\tLoss: 0.104301\n",
      "Train Epoch: 3 [40320/60000 (67%)]\tLoss: 0.163289\n",
      "Train Epoch: 3 [40960/60000 (68%)]\tLoss: 0.012309\n",
      "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.196034\n",
      "Train Epoch: 3 [42240/60000 (70%)]\tLoss: 0.026732\n",
      "Train Epoch: 3 [42880/60000 (71%)]\tLoss: 0.011892\n",
      "Train Epoch: 3 [43520/60000 (72%)]\tLoss: 0.080649\n",
      "Train Epoch: 3 [44160/60000 (74%)]\tLoss: 0.141911\n",
      "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.048305\n",
      "Train Epoch: 3 [45440/60000 (76%)]\tLoss: 0.290441\n",
      "Train Epoch: 3 [46080/60000 (77%)]\tLoss: 0.016071\n",
      "Train Epoch: 3 [46720/60000 (78%)]\tLoss: 0.030096\n",
      "Train Epoch: 3 [47360/60000 (79%)]\tLoss: 0.065505\n",
      "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.034093\n",
      "Train Epoch: 3 [48640/60000 (81%)]\tLoss: 0.032176\n",
      "Train Epoch: 3 [49280/60000 (82%)]\tLoss: 0.019540\n",
      "Train Epoch: 3 [49920/60000 (83%)]\tLoss: 0.016631\n",
      "Train Epoch: 3 [50560/60000 (84%)]\tLoss: 0.225917\n",
      "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.034671\n",
      "Train Epoch: 3 [51840/60000 (86%)]\tLoss: 0.040126\n",
      "Train Epoch: 3 [52480/60000 (87%)]\tLoss: 0.015072\n",
      "Train Epoch: 3 [53120/60000 (88%)]\tLoss: 0.199291\n",
      "Train Epoch: 3 [53760/60000 (90%)]\tLoss: 0.017611\n",
      "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.046120\n",
      "Train Epoch: 3 [55040/60000 (92%)]\tLoss: 0.068309\n",
      "Train Epoch: 3 [55680/60000 (93%)]\tLoss: 0.124606\n",
      "Train Epoch: 3 [56320/60000 (94%)]\tLoss: 0.025565\n",
      "Train Epoch: 3 [56960/60000 (95%)]\tLoss: 0.019709\n",
      "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.006606\n",
      "Train Epoch: 3 [58240/60000 (97%)]\tLoss: 0.036634\n",
      "Train Epoch: 3 [58880/60000 (98%)]\tLoss: 0.130711\n",
      "Train Epoch: 3 [59520/60000 (99%)]\tLoss: 0.047939\n",
      "\n",
      "Test set: Avg. loss: 0.1734, Accuracy: 9413/10000 (94%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Baseline evaluation with the current weights, then train for n_epochs,\n",
    "# running the test pass after every epoch.\n",
    "test()\n",
    "for epoch in range(n_epochs):\n",
    "    train(epoch + 1)\n",
    "    test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "tf_venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
